/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
 
/*!
 * \file sort_v300_impl.h
 * \brief
 */
#ifndef IMPL_SORT_SORT_SORT_V300_IMPL_H
#define IMPL_SORT_SORT_SORT_V300_IMPL_H
#include "sort_pre_impl.h"

namespace AscendC {
// Multiplier used to turn an element count into a count of T-sized units
// (count * REGION_DATA_V300_SIZE / sizeof(T)); presumably the byte size of one
// sorted region entry -- TODO confirm against the sort data layout.
constexpr uint32_t REGION_DATA_V300_SIZE = 8;
// Initial per-queue element count used by the first merge round in DoFullSort.
constexpr uint32_t SINGLE_SORT_V300_COUNT = 32;
// Number of source queues combined by a single merge (fan-in of each merge round).
constexpr uint32_t SORT_V300_LEN = 4;

/*!
 * \brief Issues one MrgSort over four equally sized queues taken from tmpLocal, writing to dstLocal.
 *
 * The four source queues start at offsets 0, baseOffset, 2 * baseOffset and 3 * baseOffset of
 * tmpLocal, and each is declared to hold singleMergeTmpElementCount elements.
 *
 * \param dstLocal                    destination tensor handed to MrgSort
 * \param tmpLocal                    tensor the four source queues are sliced from
 * \param baseOffset                  distance, in units of T, between consecutive queue starts
 * \param singleMergeTmpElementCount  element count reported for every queue
 * \param mergeTmpRepeatTimes         repeat count forwarded to MrgSort; nothing is issued when <= 0
 */
template <typename T>
__aicore__ inline void FullSortInnerLoop(const LocalTensor<T> &dstLocal, const LocalTensor<T> &tmpLocal,
    const uint32_t baseOffset, const uint16_t singleMergeTmpElementCount, const int32_t mergeTmpRepeatTimes)
{
    if (mergeTmpRepeatTimes > 0) {
        // Start offsets of the second, third and fourth queue inside tmpLocal.
        const uint32_t secondQueOffset = baseOffset;
        const uint32_t thirdQueOffset = 2 * baseOffset;
        const uint32_t fourthQueOffset = 3 * baseOffset;
        MrgSortSrcList srcQueues = MrgSortSrcList(tmpLocal[0], tmpLocal[secondQueOffset],
            tmpLocal[thirdQueOffset], tmpLocal[fourthQueOffset]);

        // Every queue carries the same number of elements in the full (non-tail) merge.
        const uint16_t queElementCounts[SORT_V300_LEN] = {
            singleMergeTmpElementCount,
            singleMergeTmpElementCount,
            singleMergeTmpElementCount,
            singleMergeTmpElementCount
        };
        // Output array required by the MrgSort signature; its contents are not read afterwards.
        uint32_t mergedCounts[SORT_V300_LEN];

        // 0b1111: all four queue bits set.
        MrgSort<T>(dstLocal, srcQueues, queElementCounts, mergedCounts, 0b1111, mergeTmpRepeatTimes);
    }
}

/*!
 * \brief Tail step of a merge round; forwards every argument unchanged to ComSortInnerLoopTail.
 *
 * This wrapper adds no logic of its own. The actual tail handling lives in ComSortInnerLoopTail,
 * which is defined outside this file (presumably in sort_pre_impl.h -- TODO confirm).
 *
 * \param dstLocal                    destination tensor
 * \param tmpLocal                    source/scratch tensor
 * \param baseOffset                  distance, in units of T, between consecutive queues
 * \param singleMergeTmpElementCount  element count of a full queue in this round
 * \param elementCountTail            element count of the last (possibly shorter) queue
 * \param mergeTmpRepeatTimes         number of full merges already issued in this round
 * \param mergeTmpTailQueNum          number of queues left over for the tail merge
 */
template <typename T>
__aicore__ inline void FullSortInnerLoopTail(const LocalTensor<T> &dstLocal, const LocalTensor<T> &tmpLocal,
    const uint32_t baseOffset, const uint16_t singleMergeTmpElementCount, const uint32_t elementCountTail,
    const int32_t mergeTmpRepeatTimes, int32_t mergeTmpTailQueNum)
{
    ComSortInnerLoopTail(dstLocal, tmpLocal, baseOffset, singleMergeTmpElementCount, elementCountTail,
        mergeTmpRepeatTimes, mergeTmpTailQueNum);
}

/*!
 * \brief Counts how many merge rounds are needed to reduce repeatTimes queues to a single one.
 *
 * Each round replaces the current queue count by Ceil(count, SORT_V300_LEN); rounds are counted
 * until the count is no longer greater than 1.
 *
 * \param repeatTimes  initial number of queues
 * \return number of rounds; 0 when repeatTimes <= 1
 */
__aicore__ inline uint32_t GetFullSortInnerLoopTimes(const int32_t repeatTimes)
{
    uint32_t rounds = 0;
    for (int32_t remaining = repeatTimes; remaining > 1; remaining = Ceil(remaining, SORT_V300_LEN)) {
        ++rounds;
    }
    return rounds;
}

/*!
 * \brief Repeatedly merges groups of SORT_V300_LEN queues until one queue remains.
 *
 * The data in dstLocal is treated as repeatTimes queues of SINGLE_SORT_V300_COUNT elements each.
 * Every round:
 *   1. merges all complete groups of SORT_V300_LEN queues from tmpLocal into dstLocal
 *      (FullSortInnerLoop),
 *   2. handles the leftover queues (FullSortInnerLoopTail),
 *   3. copies dstLocal back to tmpLocal so the next round reads the merged data,
 *   4. multiplies the queue length by SORT_V300_LEN and recomputes the queue bookkeeping.
 * A PipeBarrier<PIPE_V> separates each of these steps.
 *
 * \param dstLocal     tensor holding the input queues on entry and the merge result on exit
 * \param concatLocal  not read by this function; kept for interface compatibility
 * \param indexLocal   not read by this function; kept for interface compatibility
 * \param tmpLocal     scratch tensor; must hold at least
 *                     repeatTimes * SINGLE_SORT_V300_COUNT * REGION_DATA_V300_SIZE / sizeof(T) elements of T
 * \param repeatTimes  number of initial queues
 *
 * NOTE(review): singleMergeTmpElementCount is a uint16_t multiplied by SORT_V300_LEN each round,
 * so it wraps to 0 after six rounds (32 -> 32768 -> 0). This assumes repeatTimes is small enough
 * that a seventh round never runs -- confirm against the callers' size limits.
 */
template <typename T>
__aicore__ inline void DoFullSort(const LocalTensor<T> &dstLocal, const LocalTensor<T> &concatLocal,
    const LocalTensor<uint32_t> &indexLocal, LocalTensor<T> &tmpLocal, const int32_t repeatTimes)
{
    const uint32_t loopi = GetFullSortInnerLoopTimes(repeatTimes);
    // Length of one queue in the current round; grows by SORT_V300_LEN per round.
    uint16_t singleMergeTmpElementCount = SINGLE_SORT_V300_COUNT;
    // Total number of sort elements, and the same amount expressed in units of T.
    const uint32_t srcLocalElementCount = repeatTimes * SINGLE_SORT_V300_COUNT;
    const uint32_t dstLocalElementCount = srcLocalElementCount * REGION_DATA_V300_SIZE / sizeof(T);
    // Queue bookkeeping for the first round: total queues, leftover queues, full merges.
    int32_t mergeTmpTotalQueNum = repeatTimes;
    int32_t mergeTmpTailQueNum = repeatTimes % SORT_V300_LEN;
    int32_t mergeTmpRepeatTimes = repeatTimes / SORT_V300_LEN;

    // Seed the scratch buffer with the input queues.
    DataCopy(tmpLocal, dstLocal, dstLocalElementCount);
    PipeBarrier<PIPE_V>();
    for (uint32_t i = 0; i < loopi; i++) {
        // Distance between consecutive queues, in units of T.
        const uint32_t baseOffset = singleMergeTmpElementCount * REGION_DATA_V300_SIZE / sizeof(T);
        FullSortInnerLoop(dstLocal, tmpLocal, baseOffset, singleMergeTmpElementCount, mergeTmpRepeatTimes);
        PipeBarrier<PIPE_V>();
        // Length of the last queue: the remainder if the total is not a multiple of the queue
        // length, otherwise a full queue.
        uint16_t elementCountTail = srcLocalElementCount % singleMergeTmpElementCount ?
            srcLocalElementCount % singleMergeTmpElementCount :
            singleMergeTmpElementCount;
        FullSortInnerLoopTail(dstLocal, tmpLocal, baseOffset, singleMergeTmpElementCount, elementCountTail,
            mergeTmpRepeatTimes, mergeTmpTailQueNum);
        PipeBarrier<PIPE_V>();
        // Make this round's result the input of the next round.
        DataCopy(tmpLocal, dstLocal, dstLocalElementCount);
        PipeBarrier<PIPE_V>();

        // Prepare the next round: queues are SORT_V300_LEN times longer and there are
        // ceil(total / SORT_V300_LEN) of them.
        singleMergeTmpElementCount *= SORT_V300_LEN;
        mergeTmpTotalQueNum = mergeTmpTotalQueNum % SORT_V300_LEN ?
            mergeTmpTotalQueNum / SORT_V300_LEN + 1 :
            mergeTmpTotalQueNum / SORT_V300_LEN;
        mergeTmpTailQueNum = mergeTmpTotalQueNum % SORT_V300_LEN;
        // If the queue count divides evenly but this round's tail length differs from the new
        // queue length, route the last group of SORT_V300_LEN queues through the tail path
        // instead of the full merge.
        if (mergeTmpTailQueNum == 0 && elementCountTail != singleMergeTmpElementCount) {
            mergeTmpTailQueNum = SORT_V300_LEN;
        }
        const int32_t mergeTmpQueNum = mergeTmpTotalQueNum - mergeTmpTailQueNum;
        mergeTmpRepeatTimes = mergeTmpQueNum / SORT_V300_LEN;
    }
}

} // namespace AscendC
#endif // IMPL_SORT_SORT_SORT_V300_IMPL_H
